(ns cortex.metrics-test
  (:require
    #?(:cljs [cljs.test :refer-macros [deftest is testing]]
       :clj [clojure.test :refer [deftest is testing run-tests]])
     [cortex.metrics :as metrics]
     [cortex.loss.yolo2 :as yolo-loss]))


;; `metrics/wrongs` on three positions: the inputs differ only in the last
;; slot, and the expected result flags exactly that slot.
(deftest test-wrongs
  (testing "small example"
    (let [labels      [1 1 1]
          predictions [1 1 0]
          expected    [0 0 1]]
      (is (= expected (metrics/wrongs labels predictions))))))

;; `metrics/error-rate` where one of four positions disagrees, so the
;; expected rate is 1/4 (written as the double 0.25).
(deftest error-rate
  (testing "small example"
    (let [labels      [1 1 1 1]
          predictions [1 1 1 0]
          expected    0.25]
      (is (= expected (metrics/error-rate labels predictions))))))

;; `metrics/accuracy` where positions 0 and 3 agree and 1 and 2 differ,
;; so the expected accuracy is 2/4 (the double 0.5).
(deftest accuracy
  (testing "small example"
    (let [labels      [1 0 1 0]
          predictions [1 1 0 0]
          expected    0.5]
      (is (= expected (metrics/accuracy labels predictions))))))

;; The fixture covers all four (label, prediction) combinations:
;; (1,1) (1,0) (0,1) (0,0). Only the (1,0) slot is expected to be flagged.
(deftest false-negatives
  (testing "small example"
    (let [labels      [1 1 0 0]
          predictions [1 0 1 0]
          expected    [0 1 0 0]]
      (is (= expected (metrics/false-negatives labels predictions))))))

;; The fixture covers all four (label, prediction) combinations:
;; (1,1) (1,0) (0,1) (0,0). Only the (0,1) slot is expected to be flagged.
(deftest false-positives
  (testing "small example"
    (let [labels      [1 1 0 0]
          predictions [1 0 1 0]
          expected    [0 0 1 0]]
      (is (= expected (metrics/false-positives labels predictions))))))

;; The fixture covers all four (label, prediction) combinations:
;; (1,1) (1,0) (0,1) (0,0). Only the (0,0) slot is expected to be flagged.
(deftest true-negatives
  (testing "small example"
    (let [labels      [1 1 0 0]
          predictions [1 0 1 0]
          expected    [0 0 0 1]]
      (is (= expected (metrics/true-negatives labels predictions))))))

;; The fixture covers all four (label, prediction) combinations:
;; (1,1) (1,0) (0,1) (0,0). Only the (1,1) slot is expected to be flagged.
(deftest true-positives
  (testing "small example"
    (let [labels      [1 1 0 0]
          predictions [1 0 1 0]
          expected    [1 0 0 0]]
      (is (= expected (metrics/true-positives labels predictions))))))

;; End-to-end regression test for `metrics/all-metrics` on a fixed
;; detection-with-classification fixture (17 labels, 19 predictions).
;; Boxes are paired via `yolo-loss/iou` on their `:bounding-box` vectors, and
;; `:class` is passed as the class accessor. The trailing 0.2 is presumably
;; the IoU match threshold; confirm against cortex.metrics.
;; NOTE(review): floats are compared with exact `=`, so this pins the current
;; arithmetic bit-for-bit.
(deftest global-metrics
  (testing "Calculation of global metrics for a localization-with-classification problem.
           Test data is generated by running a yolo style model on a 4x4 grid of 36x36 grid cells, containing mnist digits randomly placed.
           The bounding box coordinate space is normalized to range [0,1], using the x-y-w-h coords.
           x-y values are incremented on each new data point to avoid overlap during iou calculation
           (thus, from reading the labels below we see the first image contained only an '8', the second contained a '7' and a '2', etc...).
           Also note that an mnist digit, being 28x28, has width and height of ~0.194."
    (let [labels      [{:class "8", :bounding-box [0.7291666666666667 0.20833333333333334 0.19444444444444453 0.19444444444444448]}
                       {:class "7", :bounding-box [1.7986111111111112 1.7916666666666665 0.19444444444444442 0.19444444444444442]}
                       {:class "2", :bounding-box [1.4166666666666665 1.4097222222222223 0.19444444444444442 0.19444444444444442]}
                       {:class "5", :bounding-box [2.3819444444444446 2.7430555555555554 0.1944444444444442 0.19444444444444464]}
                       {:class "9", :bounding-box [2.4722222222222223 2.7986111111111107 0.19444444444444464 0.1944444444444442]}
                       {:class "6", :bounding-box [3.6875 3.1805555555555554 0.19444444444444464 0.1944444444444442]}
                       {:class "0", :bounding-box [3.4027777777777777 3.6527777777777777 0.19444444444444464 0.19444444444444464]}
                       {:class "6", :bounding-box [4.4375 4.180555555555555 0.19444444444444464 0.19444444444444464]}
                       {:class "7", :bounding-box [4.104166666666667 4.604166666666667 0.19444444444444464 0.19444444444444464]}
                       {:class "9", :bounding-box [5.409722222222222 5.319444444444445 0.19444444444444464 0.19444444444444464]}
                       {:class "9", :bounding-box [5.291666666666667 5.5625 0.19444444444444464 0.19444444444444464]}
                       {:class "5", :bounding-box [6.326388888888889 6.694444444444445 0.19444444444444375 0.19444444444444464]}
                       {:class "9", :bounding-box [6.798611111111111 6.145833333333333 0.19444444444444375 0.19444444444444464]}
                       {:class "4", :bounding-box [6.819444444444445 6.472222222222222 0.19444444444444464 0.19444444444444464]}
                       {:class "0", :bounding-box [7.527777777777778 7.305555555555555 0.19444444444444464 0.19444444444444464]}
                       {:class "7", :bounding-box [8.45138888888889 8.11111111111111 0.19444444444444464 0.19444444444444464]}
                       {:class "6", :bounding-box [9.23611111111111 9.131944444444443 0.19444444444444464 0.19444444444444464]}]
          predictions [{:class "9", :bounding-box [0.7415400688897913 0.2071927432288302 0.1921456577967493 0.1933240591909967]}
                       {:class "7", :bounding-box [1.8015478804671057 1.8259636791157967 0.13996132041092157 0.13248965356749265]}
                       {:class "7", :bounding-box [1.857401351001621 1.7687411550458605 0.15562700178104638 0.16897209563507243]}
                       {:class "2", :bounding-box [1.309957786791823 1.4555298173843845 0.24309233912097028 0.23836555764457312]}
                       {:class "1", :bounding-box [2.4520679177804836 2.7798956148073133 0.16038927709573958 0.16710792346175296]}
                       {:class "0", :bounding-box [2.4852839045041906 2.780678884030566 0.1667354214036476 0.1626375764006931]}
                       {:class "6", :bounding-box [3.427538368680848 3.646641000700763 0.19297170719412327 0.1939054608482529]}
                       {:class "9", :bounding-box [4.4347589118130255 4.207270826039748 0.16990286136074229 0.175877573242059]}
                       {:class "1", :bounding-box [4.163709496636373 4.5370379054581385 0.18581208740997113 0.18525786768598795]}
                       {:class "9", :bounding-box [5.383778156538428 5.362255156091932 0.20387371024628198 0.2050466843568728]}
                       {:class "6", :bounding-box [6.761208820454091 6.2427644742263535 0.12953479878296115 0.1334987331127726]}
                       {:class "5", :bounding-box [6.840611937031372 6.402141002468932 0.18125110077847317 0.16203213612631817]}
                       {:class "5", :bounding-box [6.7858612403469305 6.415196094234717 0.103248258448974 0.09908673420118319]}
                       {:class "4", :bounding-box [6.343449833878533 6.586190654919882 0.17763392610752682 0.1705477950934986]}
                       {:class "4", :bounding-box [6.290565516462699 6.729891428339311 0.1144518728068471 0.11233165686025881]}
                       {:class "0", :bounding-box [6.759644136431401 6.249168334055671 0.19217943499481738 0.2281111386756418]}
                       {:class "4", :bounding-box [7.536159182917409 7.316686987995962 0.1827497097208859 0.17869439865881098]}
                       {:class "1", :bounding-box [7.513592056065292 7.293438534965601 0.19600594293550877 0.19676255651382402]}
                       {:class "6", :bounding-box [9.21977443565332 9.151393693085293 0.19741599531089093 0.19656574252544523]}]
          box-iou     #(yolo-loss/iou (:bounding-box %1) (:bounding-box %2))
          ;; Bound as `result`, not `metrics`, so the local does not read as
          ;; the `metrics` namespace alias used in the same expression.
          result      (metrics/all-metrics labels predictions :class box-iou 0.2)]
      ;; Expected value first, per clojure.test convention. The per-class
      ;; entries are a plain vector; `=` treats vectors and seqs with the
      ;; same elements as equal, so no quoting is needed.
      (is (= {:global-metrics    {:location-sensitivity    0.823529411764706,
                                  :location-precision      0.7368421052631579,
                                  :location-F1             0.7777777777777778,
                                  :classification-accuracy 0.2857142857142857,
                                  :global-F1               0.2222222222222222},
              :per-class-metrics [{:class "0", :location-sensitivity 0.0, :location-precision 0.0, :location-F1 0.0}
                                  {:class "2", :location-sensitivity 1.0, :location-precision 1.0, :location-F1 1.0}
                                  {:class "4", :location-sensitivity 0.0, :location-precision 0.0, :location-F1 0.0}
                                  {:class "5", :location-sensitivity 0.0, :location-precision 0.0, :location-F1 0.0}
                                  {:class "6", :location-sensitivity 0.3333333333333333, :location-precision 0.3333333333333333, :location-F1 0.3333333333333333}
                                  {:class "7", :location-sensitivity 0.3333333333333333, :location-precision 0.5, :location-F1 0.4}
                                  {:class "8", :location-sensitivity 0.0, :location-precision 1.0, :location-F1 0.0}
                                  {:class "9", :location-sensitivity 0.25, :location-precision 0.3333333333333333, :location-F1 0.2857142857142857}]}
             result)))))
